"""PyStan helper functions."""
import hashlib
import os
import tempfile

import dill  # dill needed since pickle cannot save lambda functions

import pystan


# Root of the on-disk cache: ~/.cache/pystan (``~`` is expanded at import time).
CACHE_PATH = os.path.join(os.path.expanduser('~/.cache'), 'pystan')
# Subdirectories for compiled models (used by make_model) and for sampling
# results (used by sampling).
MODELS_CACHE_PATH, FITS_CACHE_PATH = (
    os.path.join(CACHE_PATH, subdir) for subdir in ('models', 'fits')
)


def make_model(model_code, **kwargs):
    """Return a compiled ``pystan.StanModel`` for `model_code`, cached on disk.

    Compiled models are pickled with dill under MODELS_CACHE_PATH, keyed by the
    blake2b hash of `model_code`. On a cache miss the model is compiled, written
    to the cache, and then read back, so the caller always receives the
    deserialized cached object ("read your writes" principle).

    Additional arguments are passed to the StanModel constructor. NOTE: they
    are *not* part of the cache key, so a cached model for the same
    `model_code` is returned regardless of `kwargs`.

    Cached models are loaded with dill, which can execute arbitrary code, so
    the cache directory must be trusted.
    """
    model_code_blake2b = hashlib.blake2b(model_code.encode()).hexdigest()
    model_path = os.path.join(MODELS_CACHE_PATH, '{}.pkl'.format(model_code_blake2b))
    os.makedirs(MODELS_CACHE_PATH, exist_ok=True)
    try:
        with open(model_path, 'rb') as f:
            return dill.load(f)
    except FileNotFoundError:
        pass
    # compile model
    model = pystan.StanModel(model_code=model_code, **kwargs)
    # Dump to a temp file in the same directory, then atomically rename it into
    # place. An interrupted or concurrent run therefore never leaves a
    # truncated pickle at `model_path`, which would make every later load fail.
    tmp = tempfile.NamedTemporaryFile('wb', dir=MODELS_CACHE_PATH, suffix='.tmp', delete=False)
    try:
        with tmp:
            dill.dump(model, tmp)
        os.replace(tmp.name, model_path)
    except BaseException:
        os.remove(tmp.name)
        raise
    # now read cached model ("read your writes" principle)
    with open(model_path, 'rb') as f:
        return dill.load(f)


def sampling(model, data, **kwargs):
    """Do `model.sampling` but with caching.

    Unlike ``model.sampling``, this returns the draws from ``fit.extract()``
    (as read back from the on-disk cache under FITS_CACHE_PATH), not the fit
    object itself. The fit's printed summary is cached alongside and printed
    to stdout on every call.

    The cache key is a blake2b hash of the dill-serialized `data`, the model's
    ``model_code`` and the dill-serialized `kwargs`. Inputs that are logically
    equal but serialize differently (e.g. keyword arguments passed in a
    different order) presumably produce a cache miss rather than a hit.

    Cached draws are loaded with dill, which can execute arbitrary code, so
    the cache directory must be trusted.
    """
    hash_obj = hashlib.blake2b(dill.dumps(data))
    hash_obj.update(model.model_code.encode())  # will break across pystan versions
    hash_obj.update(dill.dumps(kwargs))
    data_model_params_hash = hash_obj.hexdigest()
    os.makedirs(FITS_CACHE_PATH, exist_ok=True)
    draws_path = os.path.join(FITS_CACHE_PATH, f'{data_model_params_hash}.pkl')
    summary_path = os.path.join(FITS_CACHE_PATH, f'{data_model_params_hash}-summary.txt')

    def load_cached():
        # Read both files before printing anything, so a partial cache entry
        # (summary present, draws missing) does not print the summary twice.
        with open(summary_path, 'r') as f:
            summary = f.read()
        with open(draws_path, 'rb') as f:
            draws = dill.load(f)
        print(summary)
        return draws

    def write_atomically(path, mode, dump):
        # Write to a temp file in the cache directory and rename it into place,
        # so readers never observe a partially written cache file.
        tmp = tempfile.NamedTemporaryFile(mode, dir=FITS_CACHE_PATH, suffix='.tmp', delete=False)
        try:
            with tmp:
                dump(tmp)
            os.replace(tmp.name, path)
        except BaseException:
            os.remove(tmp.name)
            raise

    try:
        return load_cached()
    except FileNotFoundError:
        pass
    fit = model.sampling(data=data, **kwargs)
    draws = fit.extract()
    # Publish the draws before the summary: the summary file existing then
    # implies the draws file is complete.
    write_atomically(draws_path, 'wb', lambda f: dill.dump(draws, f))
    write_atomically(summary_path, 'w', lambda f: print(fit, file=f))
    # now read cached result
    return load_cached()


def _test_make_model():
    """Compiling directly and going through the cache must agree on model code."""
    code = 'parameters {real y;} model {y ~ normal(0,1);}'
    reference = pystan.StanModel(model_code=code)
    # The first cached call may compile and populate the cache; the second
    # should be served from it.
    first, second = make_model(model_code=code), make_model(model_code=code)
    assert reference.model_code == first.model_code == second.model_code


def _test_sampling():
    """Repeated sampling with identical inputs must return the cached draws."""
    code = 'parameters {real y;} model {y ~ normal(0,1);}'
    model = make_model(model_code=code)
    first = sampling(model, data={})
    assert first is not None
    second = sampling(model, data={})
    # A fresh (uncached) run would presumably draw different values, so
    # matching first draws indicate the second call hit the cache.
    assert first['y'][0] == second['y'][0]


def _run_tests():
    """Run each smoke test defined in this module, in order."""
    for check in (_test_make_model, _test_sampling):
        check()


if __name__ == '__main__':
    # Direct execution runs the smoke tests. They compile a Stan model, so a
    # working PyStan installation is required.
    print('running superficial tests')
    _run_tests()
    print('tests passed')
